[AutoDiff] Add 'zeroTangentVector' property to 'Differentiable' protocol. #26521
Generally LGTM!
(There'll be merge conflicts with my WIP patch to remove `AllDifferentiableVariables`, though no big deal.)
63f996a to e5aa12e
…col.

A zero tangent vector is necessary for optimizations on models with an array of parameters, especially for optimizers that iterate over parameters using key paths. The [current implementation](https://github.com/tensorflow/swift-apis/blob/master/Sources/TensorFlow/Optimizers/MomentumBased.swift) of some key-path-based optimizers is wrong in that it won't work with models that contain an array of parameters (tangent vectors like `infinityNorm` are initialized as `.zero`).

An earlier version of these optimizers using the deprecated `AllDifferentiableVariables` property would give the correct results, but would be heavyweight and inefficient because they'd need to 1. add a constraint `TangentVector == AllDifferentiableVariables` to optimizers, and 2. make a copy of all parameters and reset them to `.zero`. Since we are deprecating `AllDifferentiableVariables`, this is not the right direction.

This problem also means that our `Differentiable` abstraction needs to provide a general mechanism for obtaining a zero tangent vector at a certain instance. Hence we add a `zeroTangentVector` property to the `Differentiable` protocol.

Zero tangent vectors do not have a canonical mathematical definition, but make sense for `Differentiable` in the standard library because Swift does not have dependent types and thus cannot have a `TangentVector` that depends on a point on a differentiable manifold. Manopt also has an API, `M.zerovec(x)`, that creates a zero tangent vector at a point (see their API doc [here](https://www.manopt.org/tutorial.html)).

Adding `zeroTangentVector` will make it possible to deprecate `AllDifferentiableVariables` completely, because currently some fast.ai notebooks depend on initializing parameter gradients using `AllDifferentiableVariables`.

The new `Differentiable` protocol looks like the following. The [design overview](http://bit.ly/swift-autodiff) has been updated to reflect this change.
```swift
protocol Differentiable {
  /// A type representing the differentiable value’s derivatives.
  /// Mathematically, this is equivalent to the tangent bundle of the
  /// differentiable manifold represented by the differentiable type.
  associatedtype TangentVector: Differentiable & AdditiveArithmetic

  /// Moves `self` along the given direction. In Riemannian geometry,
  /// this is equivalent to exponential map, which moves `self` on the
  /// geodesic surface along the given tangent vector.
  mutating func move(along direction: TangentVector)

  /// A tangent vector such that `move(along: zeroTangentVector)`
  /// will not modify `self`.
  /// - Note: `zeroTangentVector` can be `TangentVector.zero` in most cases,
  ///   but types whose tangent vectors depend on instance properties of
  ///   `self` need to provide a different implementation. For example, an
  ///   array’s zero tangent vector depends on the array’s `count`.
  var zeroTangentVector: TangentVector { get }
}
```
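To make the array case concrete, here is a minimal sketch in plain Swift that does not use the real `Differentiable` protocol or any Swift for TensorFlow API. `SketchDifferentiable`, `ArrayTangent`, and `updatedVelocity` are hypothetical names invented for this illustration, and the choice of an empty array as `ArrayTangent.zero` is an assumption made here. The sketch only shows why a per-element state update, of the kind a key-path-based optimizer might perform, has nothing to work with when its state starts from an instance-independent `.zero`.

```swift
// Hypothetical stand-in for the proposed protocol (not the real `Differentiable`).
protocol SketchDifferentiable {
    associatedtype TangentVector: AdditiveArithmetic
    mutating func move(along direction: TangentVector)
    var zeroTangentVector: TangentVector { get }
}

// Hypothetical tangent type for `[Float]`. Assumption: `zero` is the empty
// array, because a static property cannot know any instance's `count`.
struct ArrayTangent: AdditiveArithmetic, Equatable {
    var elements: [Float]

    static var zero: ArrayTangent { ArrayTangent(elements: []) }

    static func + (lhs: ArrayTangent, rhs: ArrayTangent) -> ArrayTangent {
        if lhs.elements.isEmpty { return rhs }
        if rhs.elements.isEmpty { return lhs }
        precondition(lhs.elements.count == rhs.elements.count, "Count mismatch")
        return ArrayTangent(elements: zip(lhs.elements, rhs.elements).map { $0 + $1 })
    }

    static func - (lhs: ArrayTangent, rhs: ArrayTangent) -> ArrayTangent {
        if rhs.elements.isEmpty { return lhs }
        if lhs.elements.isEmpty { return ArrayTangent(elements: rhs.elements.map { -$0 }) }
        precondition(lhs.elements.count == rhs.elements.count, "Count mismatch")
        return ArrayTangent(elements: zip(lhs.elements, rhs.elements).map { $0 - $1 })
    }
}

extension Array: SketchDifferentiable where Element == Float {
    typealias TangentVector = ArrayTangent

    mutating func move(along direction: ArrayTangent) {
        // The empty tangent acts as zero, so it leaves `self` unchanged.
        if direction.elements.isEmpty { return }
        precondition(count == direction.elements.count, "Count mismatch")
        for i in indices { self[i] += direction.elements[i] }
    }

    // The instance-dependent zero: one `0` per element of `self`.
    var zeroTangentVector: ArrayTangent {
        ArrayTangent(elements: [Float](repeating: 0, count: count))
    }
}

// Hypothetical per-element momentum update over the optimizer's state.
func updatedVelocity(
    _ velocity: ArrayTangent, gradient: ArrayTangent, momentum: Float
) -> ArrayTangent {
    var result = velocity
    for i in result.elements.indices {
        result.elements[i] = momentum * result.elements[i] + gradient.elements[i]
    }
    return result
}

let weights: [Float] = [1, 2, 3]
let gradient = ArrayTangent(elements: [0.5, 0.5, 0.5])

let emptyZero = ArrayTangent.zero          // knows nothing about `weights`
let shapedZero = weights.zeroTangentVector // has the same count as `weights`

print(emptyZero.elements.count)   // 0
print(shapedZero.elements.count)  // 3

// Starting from the empty zero, the loop has no elements to update.
let fromEmpty = updatedVelocity(emptyZero, gradient: gradient, momentum: 0.9)
// Starting from the shaped zero, every element receives its gradient.
let fromShaped = updatedVelocity(shapedZero, gradient: gradient, momentum: 0.9)

// Moving along the instance's zero tangent vector leaves it unchanged.
var moved = weights
moved.move(along: moved.zeroTangentVector)
```

In this sketch `fromEmpty` stays empty while `fromShaped` holds one value per weight, which is the distinction the `zeroTangentVector` requirement is meant to capture.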
e5aa12e to 6a334f7
@@ -152,6 +152,9 @@ public extension VectorProtocol where VectorSpaceScalar : SignedNumeric {
/// A type that mathematically represents a differentiable manifold whose
/// tangent spaces are finite-dimensional.
public protocol Differentiable {
  /// A type representing a differentiable value’s derivatives.
Big 👍 on these doc comments by the way 🙂
The API change part of this PR has been implemented as #26828.
6dcf239 to 04dca63
I accidentally deleted
Closing this obsolete PR.
Resolves TF-708.